mscJNeuralNet.trainer
Class NetTrainer

java.lang.Object
  extended bymscJNeuralNet.trainer.NetTrainer
All Implemented Interfaces:
observerPattern.Observable, java.lang.Runnable

public class NetTrainer
extends java.lang.Object
implements java.lang.Runnable, observerPattern.Observable

Diese Klasse dient zum Trainieren eines KNN. Mit Hilfe dieser Klasse kann das Trainieren des Netzes als nebenläufiger Prozess realisiert werden. Gleichzeitig stehen nicht-nebenläufige Methoden zur Verfügung, um ein Netz bis zu einem bestimmten Fehlerwert oder einem bestimmten Lernzyklus zu trainieren.

Nicht-Nebenläufiges Trainieren

Über die statischen Methoden dieser Klasse kann einfach Trainiert werden:
  • train(INetTrainingAlgorithm, double[][], double[][], int): Trainieren bis eine bestimmte Anzahl an Lernschritten absolviert wurde.
  • train(INetTrainingAlgorithm, double[][], double[][], double, int): Trainieren bis zu einem bestimmten Fehlerwert. Es werden alle Haupt-Fehlertypen aus NetPerformanceStatistics unterstützt.
  • train(INetTrainingAlgorithm, double[][], double[][], double, int, int): Trainieren bis zu einem bestimmten Fehlerwert oder bis eine bestimmte Anzahl an Lernschritten absolviert wurde.
  • Nebenläufiges Trainieren

    Das nebenläufige Trainieren erlaubt es, das Training in einem eigenen Prozess ablaufen zu lassen. Das Training kann entweder durch Aufruf der Methode stop() oder durch erreichen bestimmter Bedingungen, wie dem Unterschreiten eines Fehlerwertes oder das Absolvieren einer bestimmten Anzahl an Lernschritten angehalten werden. Das Training kann durch die Schnittstelle Observable überwacht werden, denn der NetTrainer benachrichtigt nach jedem Lernschritt alle gemeldeten Observer. Im folgenden ein einfaches Beispiel, dass das dreidimensionale Xor-Problem als nebenläufigen Prozess trainiert:

    package mscJNeuralNet.examples;
    
    import observerPattern.Observable;
    import observerPattern.Observer;
    import mscJNeuralNet.connectors.INetConnector;
    import mscJNeuralNet.connectors.RandomSymmetryBreakingNetConnector;
    import mscJNeuralNet.net.Net;
    import mscJNeuralNet.net.PatternDoesNotMatchNetException;
    import mscJNeuralNet.netPerformanceStatistics.NetPerformanceReporter;
    import mscJNeuralNet.netPerformanceStatistics.NetPerformanceStatistics;
    import mscJNeuralNet.patterns.Patterns;
    import mscJNeuralNet.trainer.NetTrainer;
    import mscJNeuralNet.trainingAlgorithms.INetTrainingAlgorithm;
    import mscJNeuralNet.trainingAlgorithms.RProp;
    
    public class TestTrippleXorLearnConcurrent implements Observer{
    
      public TestTrippleXorLearnConcurrent() 
        throws PatternDoesNotMatchNetException{
    
        // 1. Erzeugen der benötigten Klassen
    
        // MLP Net
        // MLP mit der gewünschten Schichtstruktur erstellen:
        // Eingabeschicht: 3 Neuronen
        // 1. Verdeckte Schicht: 3 Neuronen
        // Ausgabeschicht: 1 Neuron
        int [] layerSizesTrippleXOr = {3, 3, 1};
        // Net myNet = new Net(layerSizesTrippleXOr);
        Net myNet = new Net(layerSizesTrippleXOr);
    
        // BIAS wurde automatisch berücksichtigt.
    
        // INetConnector
        // Diese Klasse wird zum Initialisieren der Kantengewichte benötigt
        INetConnector lNetConnectionAlgo = new RandomSymmetryBreakingNetConnector();
    
        // INetTrainingAlgorithm
        // Diese Klasse wird zum Trainieren eines MLP benötigt
        INetTrainingAlgorithm lNetTrainAlgo = new RProp();
    
        // NetTrainer
        // Diese Klasse wird das Netz als nebenläufigen Prozess trainieren.
        NetTrainer lTrainer = new NetTrainer();
    
    
        // 2. Initialisieren des Netzes
    
        // MLP + INetConnector
        // Mit der Instanz von INetConnector inititalisieren.
        // Die Klasse RandomSymmetryBreakingNetConnector benötigt
        // keine eigenen Kontrollparameter und wird daher mit dem Wert null aufgerufen.
        lNetConnectionAlgo.connectNet(myNet, null);
    
        // Nun ist das Netz verbunden und initialisiert.
    
        // INetTrainingAlgorithm
        // Lernverfahren INetTrainingAlgorithm mit Netz verbinden.
        lNetTrainAlgo.setNet(myNet);
    
    
        // 3. Trainer konfigurieren
    
        // Dem Trainer das Lernverfahren und das MLP mitteilen 
        lTrainer.setTrainingAlgorithm(lNetTrainAlgo);
        // 100 Lernschritte lang trainieren
        lTrainer.setTargetCycles(100);
        // oder bis Fehlerwert unter 0.01
        lTrainer.setTargetError(0.01);
        // Fehlerwert soll vom Typ SSE sein
        lTrainer.setTargetErrorType(
          NetPerformanceStatistics.ERRORTYPE_averageSumOfSquaredError);
        // Lerndaten immer in derselben Reihenfolge präsentieren
        lTrainer.setUseRandomizedPatternOrder(false);
    
        // Dieses Programm als Observer des Trainers anmelden
        lTrainer.getObserverManager().addObserver(this);
    
    
        // 4. Trainieren des Netzes
    
        // TRAININGSDATEN
        // Trainingsdaten bereitstellen:
        // Trainingsdaten in ein Patterns Objekt übertragen
        // Komforbaler ist es, die Lerndaten in einer Textdatei 
        // zu speichern und hier zu laden
        Patterns lTrippleXorPat = new Patterns(3, 1, 8);
        // {-1,-1,-1}
        lTrippleXorPat.setInputToken(0, 0, -1);
        lTrippleXorPat.setInputToken(1, 0, -1);
        lTrippleXorPat.setInputToken(2, 0, -1);
        // {-1}
        lTrippleXorPat.setOutputToken(0, 0, -1);
    
        // {-1,-1,1}
        lTrippleXorPat.setInputToken(0, 1, -1);
        lTrippleXorPat.setInputToken(1, 1, -1);
        lTrippleXorPat.setInputToken(2, 1, 1);
        // {1}
        lTrippleXorPat.setOutputToken(0, 1, 1);
    
        // {-1,1,-1}
        lTrippleXorPat.setInputToken(0, 2, -1);
        lTrippleXorPat.setInputToken(1, 2, 1);
        lTrippleXorPat.setInputToken(2, 2, -1);
        // {1}
        lTrippleXorPat.setOutputToken(0, 2, 1);
    
        // {-1,1,1}
        lTrippleXorPat.setInputToken(0, 3, -1);
        lTrippleXorPat.setInputToken(1, 3, 1);
        lTrippleXorPat.setInputToken(2, 3, 1);
        // {-1}
        lTrippleXorPat.setOutputToken(0, 3, -1);
    
        // {1,-1,-1}
        lTrippleXorPat.setInputToken(0, 4, 1);
        lTrippleXorPat.setInputToken(1, 4, -1);
        lTrippleXorPat.setInputToken(2, 4, -1);
        // {1}
        lTrippleXorPat.setOutputToken(0, 4, 1);
    
        // {1,-1,1}
        lTrippleXorPat.setInputToken(0, 5, 1);
        lTrippleXorPat.setInputToken(1, 5, -1);
        lTrippleXorPat.setInputToken(2, 5, 1);
        // {-1}
        lTrippleXorPat.setOutputToken(0, 5, -1);
    
        // {1,1,-1}
        lTrippleXorPat.setInputToken(0, 6, 1);
        lTrippleXorPat.setInputToken(1, 6, 1);
        lTrippleXorPat.setInputToken(2, 6, -1);
        // {-1}
        lTrippleXorPat.setOutputToken(0, 6, -1);
    
        // {1,1,1}
        lTrippleXorPat.setInputToken(0, 7, 1);
        lTrippleXorPat.setInputToken(1, 7, 1);
        lTrippleXorPat.setInputToken(2, 7, 1);
        // {1}
        lTrippleXorPat.setOutputToken(0, 7, 1);
    
        // Trainingsdatenmenge dem Trainer mitteilen
        lTrainer.setTrainingPatterns(lTrippleXorPat);
    
        // Training beginnen
        lTrainer.start();
      }
    
      public static void main(String [] args) {
        try{
          new TestTrippleXorLearnConcurrent();
        }
        catch (PatternDoesNotMatchNetException e){
          e.printStackTrace();
        }
      }
    
      public void notify(Observable pObservable) {
        // Diese Methode wird vom NetTrainer nach jedem 
        // Lernschritt aufgerufen.
        // Das übergebene Observable Objekt ist der NetTrainer selbst
        if (pObservable instanceof NetTrainer){
          NetTrainer lTrainer = (NetTrainer) pObservable;
          // Prüfen, ob das Training vorbei ist
          if (lTrainer.hasFinished()){
            System.out.println("Training fertig.");
            System.out.println(
              NetPerformanceReporter.getNetPerformance(
                lTrainer.getLastCalcualtedNetStatistics(), 
                lTrainer.getTrainingAlgorithm().getCycle()
              )
            );
            System.exit(0);
          }
          else{
            // Sonst alle 10 Lernschritte Status ausgeben
            if (lTrainer.getTrainingAlgorithm().getCycle() % 10 == 0){
              System.out.println("Lernschritt: "+
                lTrainer.getTrainingAlgorithm().getCycle());
              System.out.println("Aktueller Netzfehler: "+
                lTrainer.getTargetErrorTypeString()+" "+
                lTrainer.getLastCalculatedError());
            }
          }
        }
      }
    }
     

    Created on 05.06.2004

    Version:
    26.06.2004
    Author:
    M. Serhat Cinar
    See Also:
    Net

    Constructor Summary
    NetTrainer()
               
     
    Method Summary
     NetPerformanceStatistics getLastCalcualtedNetStatistics()
              Liefert die akteullen Fehlerwerte des Netzes.
     double getLastCalculatedError()
              Liefert den Fehlerwert aus dem letzten Lernschritt.
     observerPattern.ObserverManager getObserverManager()
               
     int getTargetCycles()
              Liefert die Anzahl der Lernschritte, die der Trainer absolvieren soll.
     double getTargetError()
              Liefert den Fehlerwert, dessen Unterschreiten das Training beendet.
     int getTargetErrorType()
              Liefert den Typ des Fehlerwertes aus setTargetError(double).
     java.lang.String getTargetErrorTypeString()
              Liefert die Stringrepräsentation des aktuellen Fehlertyps zurück.
     INetTrainingAlgorithm getTrainingAlgorithm()
              Liefert das aktuell benutzte Lernverfahren.
     Patterns getTrainingPatterns()
              Liefert die lernenden Lerndatensätze zurück.
     boolean hasFinished()
               
     boolean isRunning()
              Prüft, ob der nebenläufige Prozess dieses Trainers aktiv ist.
     boolean isUsingRandomizedPatternOrder()
              Testet, ob der Trainer zum Trainieren die Lerndatensätze in gegebener oder in zufälliger Reihenfolge in jedem Lernschritt benutzt.
     void run()
               
     void setTargetCycles(int pTargetCycles)
              Legt fest, wieviele Lernschritte der Trainer absolvieren soll.
     void setTargetError(double pTargetError)
              Legt fest, bei welchem Fehlerwert der Trainer das Training als erfolgreich beenden soll.
     void setTargetErrorType(int pTargetErrorType)
              Typ des Fehlers, dessen Unterschreitung das Training beendet setTargetError(double).
     void setTrainingAlgorithm(INetTrainingAlgorithm pTrainingAlgorithm)
              Legt das Lernverfahren für das Training fest.
     void setTrainingPatterns(Patterns pTrainingPatterns)
              Legt die Lerndatenmenge fest.
     void setUseRandomizedPatternOrder(boolean pUseRandomizedPatternOrder)
              Legt fest, ob beim Training die Lerndatensätze immer in der gleichen Reihenfolge gelernt werden sollen, oder ob nach jedem Lernschritt die Reihenfolge zufällig ermittelt werden soll.
     void start()
              Startet den nebenläufigen Prozess zum Trainiern des Netzes.
     void stop()
              Hält den nebenläufigen Prozess zum Trainieren des Netzes an.
    static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, double pTargetError, int pErrorType)
              Trainieren bis zu einem bestimmten Fehlerwert.
    static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, double pTargetError, int pErrorType, int pCycles)
              Trainieren bis zu einem bestimmten Fehlerwert oder bis eine bestimmte Anzahl an Lernschritten absolviert wurde.
    static void train(INetTrainingAlgorithm pTrainingAlgorithm, double[][] pInputPatterns, double[][] pOutputPatterns, int pCycles)
              Trainieren bis die gegebene Anzahl an Lernschritten absolviert wurde.
     
    Methods inherited from class java.lang.Object
    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
     

    Constructor Detail

    NetTrainer

    public NetTrainer()
    Method Detail

    setTrainingAlgorithm

    public void setTrainingAlgorithm(INetTrainingAlgorithm pTrainingAlgorithm)
                              throws PatternDoesNotMatchNetException
    Legt das Lernverfahren für das Training fest. Es wird das Netz, welches im Lernverfahren registriert ist, benutzt. Am Ende der Methode werden die Observer des NetTrainer benachrichtigt.

    Parameters:
    pTrainingAlgorithm - Das zu benutzende Lernverfahren und das enthaltene Netz.
    Throws:
    PatternDoesNotMatchNetException - Falls eines der gegebenen Lerndatensätze nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.

    getTrainingAlgorithm

    public INetTrainingAlgorithm getTrainingAlgorithm()
    Liefert das aktuell benutzte Lernverfahren.

    Returns:
    Das aktuell benutzte Lernverfahren.

    setTrainingPatterns

    public void setTrainingPatterns(Patterns pTrainingPatterns)
                             throws PatternDoesNotMatchNetException
    Legt die Lerndatenmenge fest. Das übergebene Patterns Objekt sollte alle Lerndaten enthalten, die dem Netz trainiert werden sollen.

    Parameters:
    pTrainingPatterns - Die zu lernende Lerndatensätze.
    Throws:
    PatternDoesNotMatchNetException - Falls eines der gegebenen Lerndatensätze nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.

    getTrainingPatterns

    public Patterns getTrainingPatterns()
    Liefert die lernenden Lerndatensätze zurück.

    Returns:
    Die lernenden Lerndatensätze zurück.

    setTargetCycles

    public void setTargetCycles(int pTargetCycles)
    Legt fest, wieviele Lernschritte der Trainer absolvieren soll. Das nebenläufige Training wird durchgeführt, bis entweder die angegebene Anzahl Lernschritte absolviert wurde oder der durch setTargetError(double) definierte Fehlerwert unterschritten wurde.

    Parameters:
    pTargetCycles - Anzahl der Lernschritte, die der Trainer absolvieren soll.
    See Also:
    setTargetError(double)

    getTargetCycles

    public int getTargetCycles()
    Liefert die Anzahl der Lernschritte, die der Trainer absolvieren soll.

    Returns:
    Die Anzahl der Lernschritte, die der Trainer absolvieren soll.

    setTargetError

    public void setTargetError(double pTargetError)
    Legt fest, bei welchem Fehlerwert der Trainer das Training als erfolgreich beenden soll. Das nebenläufige Training wird durchgeführt, bis entweder die durch setTargetCycles(int) angegebene Anzahl Lernschritte absolviert wurden oder der definierte Fehlerwert unterschritten wurde.
    Der Typ des Fehlerwertes kann jeder beliebige Hauptfehlertyp aus der Klasse NetPerformanceStatistics sein.

    Parameters:
    pTargetError - Fehlerwert, dessen Unterschreiten das Training beendet.
    See Also:
    setTargetCycles(int), setTargetErrorType(int)

    getTargetError

    public double getTargetError()
    Liefert den Fehlerwert, dessen Unterschreiten das Training beendet.

    Returns:
    Fehlerwert, dessen Unterschreiten das Training beendet.

    setTargetErrorType

    public void setTargetErrorType(int pTargetErrorType)
    Typ des Fehlers, dessen Unterschreitung das Training beendet setTargetError(double). Als Fehlertyp kann jeder beliebige Hauptfehlertyp aus der Klasse NetPerformanceStatistics benutzt werden.

    Parameters:
    pTargetErrorType - Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
    See Also:
    setTargetError(double), getTargetErrorTypeString()

    getTargetErrorType

    public int getTargetErrorType()
    Liefert den Typ des Fehlerwertes aus setTargetError(double). Der zurückgelieferte Wert entspricht einer Konstante für den Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics

    Returns:
    Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
    See Also:
    getTargetErrorTypeString()

    getTargetErrorTypeString

    public java.lang.String getTargetErrorTypeString()
    Liefert die Stringrepräsentation des aktuellen Fehlertyps zurück.

    Returns:
    Die Stringrepräsentation des aktuellen Fehlertyps.

    isUsingRandomizedPatternOrder

    public boolean isUsingRandomizedPatternOrder()
    Testet, ob der Trainer zum Trainieren die Lerndatensätze in gegebener oder in zufälliger Reihenfolge in jedem Lernschritt benutzt.

    Returns:
    true, falls eine zufällige Reihenfolge benutzt wird, false, falls immer die gleiche Reihenfolge benutzt wird.
    See Also:
    setUseRandomizedPatternOrder(boolean)

    setUseRandomizedPatternOrder

    public void setUseRandomizedPatternOrder(boolean pUseRandomizedPatternOrder)
    Legt fest, ob beim Training die Lerndatensätze immer in der gleichen Reihenfolge gelernt werden sollen, oder ob nach jedem Lernschritt die Reihenfolge zufällig ermittelt werden soll.
    Eine zufällige Präsentationsreihenfolge der Lerndatensätze erhöht die Generalisierungsfähigkeit des Netzes, was aber geichzeitig bedeutet, dass der Netzfehler für die Lerndatenmenge ebenfalls größer wird. (Generalisierung vs. Spezialisierung auf die Lerndatensätze).

    Parameters:
    pUseRandomizedPatternOrder - true = zufällige Reihenfolge benutzen, false = immer die gleiche Reihenfolge benutzen.
    See Also:
    isUsingRandomizedPatternOrder(), Patterns.getRandomizedPatternsOrder()

    getLastCalculatedError

    public double getLastCalculatedError()
    Liefert den Fehlerwert aus dem letzten Lernschritt. Der Typ des Fehlerwertes kann durch die Methode getTargetErrorType() ermittelt werden.

    Returns:
    Aktueller Fehlerwert des Netzes.
    See Also:
    getTargetErrorType(), getLastCalcualtedNetStatistics()

    getLastCalcualtedNetStatistics

    public NetPerformanceStatistics getLastCalcualtedNetStatistics()
    Liefert die akteullen Fehlerwerte des Netzes.

    Returns:
    Aktuelle Fehlerwerte des Netzes.

    hasFinished

    public boolean hasFinished()

    start

    public void start()
    Startet den nebenläufigen Prozess zum Trainiern des Netzes. Vorher muss dem Trainer über setTrainingAlgorithm(INetTrainingAlgorithm) ein Lernverfahren mit Netz und über setTrainingPatterns(Patterns) Lerndatensätze zugeteilt werden. Zusätzlich sollte und Aufruf von setTargetCycles(int), setTargetError(double), setTargetErrorType(int) das gewünschte Lernziel festgelegt werden.


    stop

    public void stop()
    Hält den nebenläufigen Prozess zum Trainieren des Netzes an.


    isRunning

    public boolean isRunning()
    Prüft, ob der nebenläufige Prozess dieses Trainers aktiv ist.

    Returns:
    true, falls dieser Trainer gerade aktiv ist, sonst false.

    run

    public void run()
    Specified by:
    run in interface java.lang.Runnable

    getObserverManager

    public observerPattern.ObserverManager getObserverManager()
    Specified by:
    getObserverManager in interface observerPattern.Observable

    train

    public static void train(INetTrainingAlgorithm pTrainingAlgorithm,
                             double[][] pInputPatterns,
                             double[][] pOutputPatterns,
                             int pCycles)
                      throws PatternDoesNotMatchNetException
    Trainieren bis die gegebene Anzahl an Lernschritten absolviert wurde.

    Parameters:
    pTrainingAlgorithm - Lernverfahren, das beim Training bernutzt werden soll. Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.
    pInputPatterns - Eingabemuster der Lerndatensätze.
    pOutputPatterns - Ausgabemuster der Lerndatensätze (Soll-Werte).
    pCycles - Anzahl der Lernschritte, die absolviert werden sollen.
    Throws:
    PatternDoesNotMatchNetException - Falls eines der gegebenen Lerndatensätze nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.

    train

    public static void train(INetTrainingAlgorithm pTrainingAlgorithm,
                             double[][] pInputPatterns,
                             double[][] pOutputPatterns,
                             double pTargetError,
                             int pErrorType)
                      throws PatternDoesNotMatchNetException
    Trainieren bis zu einem bestimmten Fehlerwert. Es werden alle Haupt-Fehlertypen aus NetPerformanceStatistics unterstützt.

    Parameters:
    pTrainingAlgorithm - Lernverfahren, das beim Training bernutzt werden soll. Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.
    pInputPatterns - Eingabemuster der Lerndatensätze.
    pOutputPatterns - Ausgabemuster der Lerndatensätze (Soll-Werte).
    pTargetError - Zielwert des Fehlers, bei dem das Training gestoppt werden soll.
    pErrorType - Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
    Throws:
    PatternDoesNotMatchNetException - Falls eines der gegebenen Lerndatensätze nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.
    See Also:
    NetPerformanceStatistics

    train

    public static void train(INetTrainingAlgorithm pTrainingAlgorithm,
                             double[][] pInputPatterns,
                             double[][] pOutputPatterns,
                             double pTargetError,
                             int pErrorType,
                             int pCycles)
                      throws PatternDoesNotMatchNetException
    Trainieren bis zu einem bestimmten Fehlerwert oder bis eine bestimmte Anzahl an Lernschritten absolviert wurde. Es werden alle Haupt-Fehlertypen aus NetPerformanceStatistics unterstützt.
    Je nachdem, welche der beiden Bedingungen, der minimale Fehlerwert oder die maximale Anzahl der absolvierten Lernschritte, zuerst erreicht wird, beendet das Training.

    Parameters:
    pTrainingAlgorithm - Lernverfahren, das beim Training bernutzt werden soll. Das Lernverfahren enthält auch die Referenz zum Netz, das trainiert werden soll.
    pInputPatterns - Eingabemuster der Lerndatensätze.
    pOutputPatterns - Ausgabemuster der Lerndatensätze (Soll-Werte).
    pTargetError - Zielwert des Fehlers, bei dem das Training gestoppt werden soll
    pErrorType - Fehlertyp (ERRORTYPE-Konstante) aus der Klasse NetPerformanceStatistics
    pCycles - Anzahl der Lernschritte, die absolviert werden sollen.
    Throws:
    PatternDoesNotMatchNetException - Falls eines der gegebenen Lerndatensätze nicht der Eingabe-/ Ausgabeschichtgröße des Netzes entspricht.
    See Also:
    NetPerformanceStatistics